深度学习--Pytorch实现分类模型
The following article is from 探索GIS的小蜗牛 Author 卫小澄
在分类问题中,通常标签都为类别,可以用离散值来代替。例如,在图像分类问题中,假设一张图片是
Softmax回归同线性回归一样,也是一个单层神经网络。由于每个输出
既然分类问题需要得到离散的预测输出,一个简单的办法是将输出值
但是这样存在问题:
一方面,由于输出层的输出值的范围不确定,我们难以直观上判断这些值的意义。如果某个样本输出值
, , 分别为0.1,10,0.1,那么说明该样本非常可能为第二类,但是如果另外一个样本的输出值 , , 分别为1000,10,1000,那这个10却表示的是为第二类的概率很低; 另一方面,由于真实标签是离散值,这些离散值与不确定范围的输出值之间的误差难以衡量。
因此需要将每个样本的输出值归一化,即softmax运算:
其中
可以看出
一般的,softmax回归的矢量计算表达式为:
假设样本数为
03交叉熵损失函数
交叉熵(cross entropy)是一个常用的衡量两个概率分布差异的测量函数:
例如,样本
假设训练数据集的样本数为
最小化交叉熵损失函数等价于最大化训练数据集所有标签类别的联合预测概率。
在训练好softmax回归模型后,给定任一样本特征,就可以预测每个输出类别的概率。通常,我们把预测概率最大的类别作为输出类别。如果它与真实类别(标签)一致,说明这次预测是正确的。可以使用准确率(accuracy)来评价模型的表现。它等于正确预测数量与总预测数量之比。
05获取数据此处使用的是使用一个图像内容更加复杂的数据集Fashion-MNIST[2]。去下载。
这里我们会使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。
数据展示:
import matplotlib.pyplot as plt
import torch
import torchvision
import numpy as np
import torchvision.transforms as transforms
import time
import sys
def getData():
# train:指定是否为训练集 download:是否下载,如果本地已经有了则不会下载
# (可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop。
mnist_train = torchvision.datasets.FashionMNIST(root='data/', train=True, download=True, transform=transforms.ToTensor())
# 如果无法下载,可https://github.com/zalandoresearch/fashion-mnist
# 放进data/FashionMNIST\raw\train-images-idx3-ubyte.gz
mnist_test = torchvision.datasets.FashionMNIST(root='data/', train=False, download=True, transform=transforms.ToTensor())
return mnist_train, mnist_test
# softmax运算
def softmax(X):
X_exp = X.exp()
partition = X_exp.sum(dim=1, keepdim=True)
return X_exp / partition # 这里应用了广播机制
# softmax模型
def model(X, W, b):
return softmax(torch.mm(X.view((-1, W.shape[0])), W) + b)
# 损失函数
def cross_entropy(y_hat, y):
# 按y中的数字取相应的列
return - torch.log(y_hat.gather(1, y.view(-1, 1)))
# 准确率
def evaluate_accuracy(data_iter, net, params):
acc_sum, n = 0.0, 0
for X, y in data_iter:
acc_sum += (net(X, params[0], params[1]).argmax(dim=1) == y).float().sum().item()
n += y.shape[0]
return acc_sum / n
def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
params=None, lr=None, optimizer=None):
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
for X, y in train_iter:
y_hat = net(X, params[0], params[1])
l = loss(y_hat, y).sum()
# 梯度清零
if optimizer is not None:
optimizer.zero_grad()
elif params is not None and params[0].grad is not None:
for param in params:
param.grad.data.zero_()
l.backward()
if optimizer is None:
sgd(params, lr, batch_size) # 梯度更新
else:
optimizer.step()
train_l_sum += l.item()
train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
n += y.shape[0]
test_acc = evaluate_accuracy(test_iter, net, params)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
def predict(net, test_iter, params):
X, y = iter(test_iter).next()
true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X, params[0], params[1]).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(X[0:9], titles[0:9])
def main():
# 超参数
num_epochs, lr = 5, 0.1
batch_size = 256 # 批量大小
num_inputs = 28*28 # 图像的高和宽
num_outputs = 10 # 一共十种类别
train_iter, test_iter = loadData(batch_size)
W, b = init_parameters(num_inputs, num_outputs)
train(model, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)
# 预测
predict(model, test_iter, [W, b])
if __name__ == '__main__':
# getData() # 获取数据
# show() # 展示数据
main() # 自己实现
07结果及预测epoch 1, loss 0.7874, train acc 0.749, test acc 0.791
epoch 2, loss 0.5712, train acc 0.812, test acc 0.813
epoch 3, loss 0.5250, train acc 0.826, test acc 0.820
epoch 4, loss 0.5018, train acc 0.833, test acc 0.824
epoch 5, loss 0.4860, train acc 0.837, test acc 0.827
可以看出测试集上能够达到80%以上的准确率。模型训练完后,便可以进行预测。给定一系列图像(第三行图像输出),比较一下它们的真实标签(第一行文本输出)和模型预测结果(第二行文本输出)。
如需源码,请后台回复 "softmax回归"。有什么问题,可添加微信 "wxid-3ccc".